Machine learning to segment neutron images

Anders Kaestner, Beamline scientist - Neutron Imaging

Laboratory for Neutron Scattering and Imaging
Paul Scherrer Institut

Lecture outline

  1. Introduction
  2. Limited data problem
  3. Unsupervised segmentation
  4. Supervised segmentation
  5. Final problem: Segmenting root networks using convolutional NNs
  6. Future Machine learning challenges in NI

Getting started

If you want to run the notebook on your own computer, you'll need to perform the following steps:

  • You will need to install Anaconda
  • Clone the lecture repository (in the location you'd like to have it)
    git clone https://github.com/ImagingLectures/MLSegmentation4NI.git
    
  • Enter the folder 'MLSegmentation4NI'
  • Create an environment for the notebook
    conda env create -f environment.yml -n MLSeg4NI
    
  • Enter the environment
    conda activate MLSeg4NI
    

Importing needed modules

This lecture needs some modules to run. We import all of them here.

In [1]:
import matplotlib.pyplot as plt
import seaborn           as sn
import numpy             as np
import pandas            as pd
import skimage.filters   as flt
import skimage.io        as io
import matplotlib        as mpl

from sklearn.cluster     import KMeans
from sklearn.neighbors   import KNeighborsClassifier
from sklearn.metrics     import confusion_matrix
from sklearn.datasets    import make_blobs

from matplotlib.colors   import ListedColormap
from matplotlib.patches  import Ellipse
from lecturesupport      import plotsupport as ps

import scipy.stats       as stats
import astropy.io.fits   as fits

from keras.models        import Model
from keras.layers        import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate

%matplotlib inline


from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'png')
#plt.style.use('seaborn')
mpl.rcParams['figure.dpi'] = 300
Using TensorFlow backend.
In [2]:
import importlib
importlib.reload(ps);

Introduction

  • Introduction to neutron imaging

    • Some words about the method
    • Contrasts
  • Introduction to segmentation

    • What is segmentation
    • Noise and SNR
  • Problematic segmentation tasks

    • Intro
    • Segmentation problems in neutron imaging

What is an image?

A very abstract definition:

  • A pairing between spatial information (position)
  • and some other kind of information (value).

In most cases this is a two- or three-dimensional position (x, y, z coordinates) and a numeric value (intensity).
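This pairing can be made explicit with NumPy (a small illustrative sketch, not part of the lecture code):

```python
import numpy as np

# A small 2x3 "image": the array index is the position, the entry the intensity
img = np.array([[0.1, 0.5, 0.9],
                [0.2, 0.6, 1.0]])

# Make the (position, value) pairing explicit as (row, col, intensity) triples
rows, cols = np.meshgrid(np.arange(img.shape[0]), np.arange(img.shape[1]), indexing='ij')
pairs = np.stack([rows.ravel(), cols.ravel(), img.ravel()], axis=1)

print(pairs.shape)  # (6, 3): six pixels, each a position paired with a value
```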

Science and Imaging

Images are great for qualitative analyses since our brains can quickly interpret them without large programming investments.

Proper processing and quantitative analysis is, however, much more difficult with images.

  • If you measure a temperature, quantitative analysis is easy: $T=50K$.
  • If you measure an image, it is much more difficult and much more prone to mistakes:
    • subtle setup variations may break your analysis process,
    • and unclear problem definitions lead to confusing analyses.

Furthermore, in image processing there is a plethora of tools available:

  • Thousands of algorithms available
  • Thousands of tools
  • Many images require multi-step processing
  • Experimenting is time-consuming

Some words about neutron imaging

The transmitted radiation is described by Beer-Lambert's law which in its basic form looks like

$$I=I_0\cdot{}e^{-\int_L \mu{}(x) dx}$$
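For a homogeneous sample the line integral reduces to $\mu L$, which gives a one-line sanity check (the attenuation coefficient and path length below are hypothetical values chosen for illustration):

```python
import numpy as np

# Beer-Lambert for a homogeneous sample: the line integral reduces to mu * L
mu = 0.35   # attenuation coefficient in 1/cm (hypothetical value)
L  = 2.0    # path length through the sample in cm
I0 = 1000.0 # incident intensity (counts)

I = I0 * np.exp(-mu * L)
print(I)  # transmitted intensity in counts
```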

Image types obtained with neutron imaging

Fundamental information  Additional dimensions  Derived information
2D Radiography           Time series            q-values
3D Tomography            Spectra                Strain
                                                Crystal orientation

Neutron imaging contrast


[Figure: transmission through a sample, comparing X-ray and neutron attenuation]

Measurements are rarely perfect

Factors affecting the image quality

  • Resolution (Imaging system transfer functions)
  • Noise
  • Contrast
  • Inhomogeneous contrast
  • Artifacts

Introduction to segmentation

Different types of segmentation

Basic segmentation: Applying a threshold to an image

Start out with a simple image of a cross with added noise

$$ I(x,y) = f(x,y) $$
In [3]:
fig,ax = plt.subplots(1,2,figsize=(12,6))
nx = 5; ny = 5;
# Create the test image
xx, yy   = np.meshgrid(np.arange(-nx, nx+1)/nx*2*np.pi, np.arange(-ny, ny+1)/ny*2*np.pi)
cross_im = 1.5*np.abs(np.cos(xx*yy))/(np.abs(xx*yy)+(3*np.pi/nx)) + np.random.uniform(-0.25, 0.25, size = xx.shape)       

# Show it
im=ax[0].imshow(cross_im, cmap = 'hot'); ax[0].set_title("Image")
ax[1].hist(cross_im.ravel(),bins=10); ax[1].set_xlabel('Gray value'); ax[1].set_ylabel('Counts'); ax[1].set_title("Histogram");

Applying a threshold to an image

Applying the threshold is a deceptively simple operation

$$ I(x,y) = \begin{cases} 1, & f(x,y)\geq0.40 \\ 0, & f(x,y)<0.40 \end{cases}$$
In [4]:
threshold = 0.4; thresh_img = cross_im > threshold
fig,ax = plt.subplots(1,2,figsize=(12,6))
ax[0].imshow(cross_im, cmap = 'hot', extent = [xx.min(), xx.max(), yy.min(), yy.max()]); ax[0].set_title("Image")
ax[0].plot(xx[np.where(thresh_img)]*0.9, yy[np.where(thresh_img)]*0.9,
           'ks', markerfacecolor = 'green', alpha = 0.5,label = 'Threshold', markersize = 22); ax[0].legend(fontsize=12);
ax[1].hist(cross_im.ravel(),bins=10); ax[1].axvline(x=threshold,color='r',label='Threshold'); ax[1].legend(fontsize=12); 
ax[1].set_xlabel('Gray value'); ax[1].set_ylabel('Counts'); ax[1].set_title("Histogram");

Noise and SNR

The noise in neutron imaging mainly originates from the amount of captured neutrons.

This noise is Poisson distributed and the signal to noise ratio is

$$SNR=\frac{E[x]}{s[x]}\sim\frac{N}{\sqrt{N}}=\sqrt{N}$$
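A quick numerical check of this relation: draw Poisson-distributed counts and compare the measured SNR to $\sqrt{N}$ (a sketch, not lecture code):

```python
import numpy as np

rng = np.random.default_rng(42)
for N in [100, 1000, 10000]:
    x = rng.poisson(N, size=100000)  # counts with expected value N
    snr = x.mean() / x.std()         # E[x]/s[x]
    print(N, snr, np.sqrt(N))        # the measured SNR should be close to sqrt(N)
```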

Problematic segmentation tasks

Woodland Encounter, Bev Doolittle

Typical image features that make life harder

In neutron imaging you see all these image phenomena.

Limited data problem

Different types of limited data:

  • Few data points or limited amounts of images
  • Unbalanced data
  • Little or missing training data

Training data from NI is limited

  • Long experiment times
  • Few samples
  • Some recycling from previous experiments is possible.

Augmentation to increase training data

Data augmentation is a method to modify your existing data to obtain variations of it.

Augmentation will be used to increase the training data in the root segmentation example at the end of this lecture.
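A minimal augmentation sketch using only NumPy flips and rotations (the root-segmentation example later uses its own pipeline; this just illustrates the idea):

```python
import numpy as np

def augment(img, mask):
    """Yield flipped/rotated copies of an image and its mask (a minimal sketch)."""
    for k in range(4):                           # four 90-degree rotations
        rimg, rmask = np.rot90(img, k), np.rot90(mask, k)
        yield rimg, rmask
        yield np.fliplr(rimg), np.fliplr(rmask)  # plus a mirrored copy of each

img  = np.random.uniform(size=(32, 32))
mask = (0.5 < img)
augmented = list(augment(img, mask))
print(len(augmented))  # 8 variations from a single annotated image
```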

Simulation to increase training data

  • Geometric models
  • Template models
  • Physical models

Both augmented and simulated data should be combined with real data.
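A toy geometric model might look like this: random outlier pixels on a noisy flat background, which conveniently yields both the 'dirty' image and its ground truth (all parameter values are made up for illustration):

```python
import numpy as np

def simulate_spots(shape=(128, 128), n_spots=20, amplitude=2000.0, seed=0):
    """Geometric model: random bright spots on a flat background, with ground truth."""
    rng  = np.random.default_rng(seed)
    img  = np.full(shape, 1000.0)           # flat background level
    gt   = np.zeros(shape, dtype=bool)      # ground-truth spot mask
    rows = rng.integers(0, shape[0], n_spots)
    cols = rng.integers(0, shape[1], n_spots)
    img[rows, cols] += amplitude            # add the spots
    gt[rows, cols]   = True
    img += rng.normal(0, 50.0, shape)       # measurement noise
    return img, gt

img, gt = simulate_spots()
print(img.shape, gt.sum())  # image size and number of simulated outlier pixels
```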

Transfer learning

Transfer learning is a technique that uses a pre-trained network to

  • Speed up training on your current data
  • Support in cases of limited data
  • Improve network performance

Unsupervised segmentation

Introducing clustering

In [5]:
test_pts = pd.DataFrame(make_blobs(n_samples=200, random_state=2018)[
                        0], columns=['x', 'y'])
plt.plot(test_pts.x, test_pts.y, 'r.');

k-means

The user only has to provide the number of classes the algorithm shall find.

Note: The algorithm will find exactly the number of classes you ask for; it doesn't care if it makes sense!

Basic clustering example

In [6]:
N=3
fig, ax = plt.subplots(1,N,figsize=(18,4.5))

for i in range(N) :
    km = KMeans(n_clusters=i+2, random_state=2018); n_grp = km.fit_predict(test_pts)
    ax[i].scatter(test_pts.x, test_pts.y, c=n_grp)
    ax[i].set_title('{0} groups'.format(i+2))

Add spatial information to k-means

In [7]:
orig = fits.getdata('../data/spots/mixture12_00001.fits')[::4,::4]
fig,ax = plt.subplots(1,6,figsize=(18,5)); x,y = np.meshgrid(np.linspace(0,1,orig.shape[0]),np.linspace(0,1,orig.shape[1]))
ax[0].imshow(orig, vmin=0, vmax=4000), ax[0].set_title('Original')
ax[1].imshow(x), ax[1].set_title('x-coordinates')
ax[2].imshow(y), ax[2].set_title('y-coordinates')
ax[3].imshow(flt.gaussian(orig, sigma=5)), ax[3].set_title('Weighted neighborhood')
ax[4].imshow(flt.sobel_h(orig),vmin=0, vmax=0.001),ax[4].set_title('Horizontal edges')
ax[5].imshow(flt.sobel_v(orig),vmin=0, vmax=0.001),ax[5].set_title('Vertical edges');

When can clustering be used on images?

  • Single images
  • Bimodal data
  • Spectrum data

Clustering applied to wavelength resolved imaging

The imaging techniques and its applications

The data

In [8]:
tof  = np.load('../data/tofdata.npy')
wtof = tof.mean(axis=2)
plt.imshow(wtof,cmap='gray'); 
plt.title('Average intensity all time bins');

Looking at the spectra

In [9]:
fig, ax= plt.subplots(1,2,figsize=(12,5))
ax[0].imshow(wtof,cmap='gray'); ax[0].set_title('Average intensity all time bins');
ax[0].plot(57,3,'ro'), ax[0].plot(15,30,'bo'), ax[0].plot(79,90,'go'); ax[0].plot(100,120,'co');
ax[1].plot(tof[30,15,:],'b', label='Sample'); ax[1].plot(tof[3,57,:],'r', label='Background'); ax[1].plot(tof[90,79,:],'g', label='Spacer'); ax[1].legend();ax[1].plot(tof[120,100,:],'c', label='Sample 2');

Reshaping

In [10]:
tofr=tof.reshape([tof.shape[0]*tof.shape[1],tof.shape[2]])
print("Input ToF dimensions",tof.shape)
print("Reshaped ToF data",tofr.shape)
Input ToF dimensions (128, 128, 661)
Reshaped ToF data (16384, 661)

Setting up and running k-means

  • We can clearly see that there is void on the sides of the specimens.
  • There is also a separating band between the specimens.
  • Finally we have to decide how many regions we want to find in the specimens. Let's start with two regions with different characteristics.
In [11]:
km = KMeans(n_clusters=4, random_state=2018)     # Random state is an initialization parameter for the random number generator
c  = km.fit_predict(tofr).reshape(tof.shape[:2]) # Label image
kc = km.cluster_centers_.transpose()             # cluster centroid spectra

Results from the first try

In [12]:
fig,axes = plt.subplots(1,3,figsize=(18,5)); axes=axes.ravel()
axes[0].imshow(wtof,cmap='viridis'); axes[0].set_title('Average image')
p=axes[1].plot(kc);                  axes[1].set_title('Cluster centroid spectra'); axes[1].set_aspect(tof.shape[2], adjustable='box')
cmap=ps.buildCMap(p) # Create a color map with the same colors as the plot

im=axes[2].imshow(c,cmap=cmap); plt.colorbar(im);
axes[2].set_title('Cluster map');
plt.tight_layout()

We need more clusters

  • Experiment data has variations in places we didn't expect k-means to detect as clusters.
  • We need to increase the number of clusters!

Increasing the number of clusters

What happens when we increase the number of clusters to ten?

In [13]:
km = KMeans(n_clusters=10, random_state=2018)
c  = km.fit_predict(tofr).reshape(tof.shape[:2]) # Label image
kc = km.cluster_centers_.transpose()             # cluster centroid spectra

Results of k-means with ten clusters

In [14]:
fig,axes = plt.subplots(1,3,figsize=(18,5)); axes=axes.ravel()
axes[0].imshow(wtof,cmap='gray'); axes[0].set_title('Average image')
p=axes[1].plot(kc);                  axes[1].set_title('Cluster centroid spectra'); axes[1].set_aspect(tof.shape[2], adjustable='box')
cmap=ps.buildCMap(p) # Create a color map with the same colors as the plot

im=axes[2].imshow(c,cmap=cmap); plt.colorbar(im);
axes[2].set_title('Cluster map');
plt.tight_layout()

Interpreting the clusters

In [15]:
fig,axes = plt.subplots(1,1,figsize=(14,5)); 
plt.plot(kc); axes.set_title('Cluster centroid spectra'); 
axes.add_patch(Ellipse((0,0.62), width=30,height=0.55,fill=False,color='r')) #,axes.set_aspect(tof.shape[2], adjustable='box')
axes.add_patch(Ellipse((0,0.24), width=30,height=0.15,fill=False,color='cornflowerblue')),axes.set_aspect(tof.shape[2], adjustable='box');

Cleaning up the workspace

In [16]:
del km, c, kc, tofr, tof

Supervised segmentation

  1. Training: Requires training data
  2. Verification: Requires verification data
  3. Inference: The images you want to segment

k nearest neighbors

Create example data for supervised segmentation

In [17]:
blob_data, blob_labels = make_blobs(n_samples=100, random_state=2018)
test_pts = pd.DataFrame(blob_data, columns=['x', 'y'])
test_pts['group_id'] = blob_labels
plt.scatter(test_pts.x, test_pts.y, c=test_pts.group_id, cmap='viridis');

Detecting unwanted outliers in neutron images

In [18]:
orig= fits.getdata('../data/spots/mixture12_00001.fits')
annotated=io.imread('../data/spots/mixture12_00001.png'); mask=(annotated[:,:,1]==0)
r=600; c=600; w=256
ps.magnifyRegion(orig,[r,c,r+w,c+w],[15,7],vmin=400,vmax=4000,title='Neutron radiography')

Marked-up spots

Baseline - Traditional spot cleaning algorithm

Parameters

  • N Width of median filter.
  • k Threshold level for outlier detection.

The spot cleaning algorithm

In [19]:
def spotCleaner(img, threshold=0.95, selem=np.ones([3,3])) :
    fimg=img.astype('float32')
    mimg = flt.median(fimg,selem=selem)
    timg = threshold < np.abs(fimg-mimg)
    cleaned = mimg * timg + fimg * (1-timg)
    return (cleaned,timg)
In [20]:
baseclean,timg = spotCleaner(orig,threshold=1000)
ps.magnifyRegion(baseclean,[r,c,r+w,c+w],[12,3],vmin=400,vmax=4000,title='Cleaned image')
ps.magnifyRegion(timg,[r,c,r+w,c+w],[12,3],vmin=0,vmax=1,title='Detection image')

k nearest neighbors to detect spots

In [21]:
selem=np.ones([3,3])
forig=orig.astype('float32')
mimg = flt.median(forig,selem=selem)
d = np.abs(forig-mimg)

fig,ax=plt.subplots(1,1,figsize=(8,5))
h,x,y,u=ax.hist2d(forig[:1024,:].ravel(),d[:1024,:].ravel(), bins=100);
ax.imshow(np.log(h[::-1]+1),vmin=0,vmax=3,extent=[x.min(),x.max(),y.min(),y.max()])
ax.set_xlabel('Input image - $f$'),ax.set_ylabel('$|f-med_{3x3}(f)|$'),ax.set_title('Log bivariate histogram');

Prepare data

Training data

In [22]:
trainorig = forig[:,:1000].ravel()
traind    = d[:,:1000].ravel()
trainmask = mask[:,:1000].ravel()

train_pts = pd.DataFrame({'orig': trainorig, 'd': traind, 'mask':trainmask})

Test data

In [23]:
testorig = forig[:,1000:].ravel()
testd    = d[:,1000:].ravel()
testmask = mask[:,1000:].ravel()

test_pts = pd.DataFrame({'orig': testorig, 'd': testd, 'mask':testmask})

Train the model

In [24]:
k_class = KNeighborsClassifier(1)
k_class.fit(train_pts[['orig', 'd']], train_pts['mask']) 
Out[24]:
KNeighborsClassifier(n_neighbors=1)

Inspect decision space

In [25]:
xx, yy = np.meshgrid(np.linspace(test_pts.orig.min(), test_pts.orig.max(), 100),
                     np.linspace(test_pts.d.min(), test_pts.d.max(), 100),indexing='ij');
grid_pts = pd.DataFrame(dict(x=xx.ravel(), y=yy.ravel()))
grid_pts['predicted_id'] = k_class.predict(grid_pts[['x', 'y']])
plt.scatter(grid_pts.x, grid_pts.y, c=grid_pts.predicted_id, cmap='gray'); plt.title('Testing Points'); plt.axis('square');

Apply knn to unseen data

In [26]:
pred = k_class.predict(test_pts[['orig', 'd']])
pimg = pred.reshape(d[:,1000:].shape)
In [27]:
fig,ax = plt.subplots(1,3,figsize=(15,6))
ax[0].imshow(forig[:,1000:],vmin=0,vmax=4000), ax[0].set_title('Original image')
ax[1].imshow(pimg), ax[1].set_title('Predicted spots')
ax[2].imshow(mask[:,1000:]),ax[2].set_title('Annotated spots');

Performance check

In [28]:
cmbase = confusion_matrix(mask[:,1000:].ravel(), timg[:,1000:].ravel(), normalize='all')
cmknn  = confusion_matrix(mask[:,1000:].ravel(), pimg.ravel(), normalize='all')
In [29]:
fig,ax = plt.subplots(1,2,figsize=(10,4))
sn.heatmap(cmbase, annot=True,ax=ax[0]), ax[0].set_title('Confusion matrix baseline');
sn.heatmap(cmknn, annot=True,ax=ax[1]), ax[1].set_title('Confusion matrix k-NN');

Some remarks about k-nn

  • It takes more time to process
  • You need to prepare training data
    • Annotation takes time...
    • Here we used the segmentation on the same type of image
    • We should normalize the data
    • This was a raw projection, what happens if we use a flat field corrected image?
  • Finds more spots than baseline
  • Data is very unbalanced, try a selection of non-spot data for training.
    • Is it faster?
    • Is there a drop in segmentation performance?

Note: There are other spot detection methods that perform better than the baseline.

Clean up

In [30]:
del k_class, cmbase, cmknn

Convolutional neural networks for segmentation

In [31]:
import keras.optimizers as opt
import keras.losses as loss
import keras.metrics as metrics

Training data

We have two choices:

  1. Use real data
    • requires time-consuming markup to provide training data
    • corresponds to real-life images
  2. Synthesize data
    • flexible and provides both 'dirty' data and ground truth
    • the model may not behave like real data

Preparing real data

We will use the spotty image as training data for this example

Prepare training, validation, and test data

Any analysis system must be verified to demonstrate its performance and to further optimize it.

For this we need to split our data into three categories:

  1. Training data
  2. Test data
  3. Validation data
Training  Validation  Test
70%       15%         15%
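One way to obtain such a 70/15/15 split is to apply scikit-learn's train_test_split twice (a sketch with stand-in data; the fractions are the ones from the table):

```python
import numpy as np
from sklearn.model_selection import train_test_split

X = np.arange(1000).reshape(-1, 1)  # stand-in for 1000 samples
y = np.zeros(1000)

# First split off 30%, then divide that remainder half-and-half into validation and test
X_train, X_rest, y_train, y_rest = train_test_split(X, y, test_size=0.30, random_state=42)
X_val, X_test, y_val, y_test     = train_test_split(X_rest, y_rest, test_size=0.50, random_state=42)

print(len(X_train), len(X_val), len(X_test))  # 700 150 150
```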

Build a CNN for spot detection and cleaning

We need:

  • Data
  • Tensorflow
    • Data provider
    • Model design

Build a U-Net model

In [32]:
def buildSpotUNet( base_depth = 48) :
    in_img = Input((None, None, 1), name='Image_Input')
    lay_1 = Conv2D(base_depth, kernel_size=(3, 3), padding='same',activation='relu')(in_img)
    lay_2 = Conv2D(base_depth, kernel_size=(3, 3), padding='same',activation='relu')(lay_1)
    lay_3 = MaxPooling2D(pool_size=(2, 2))(lay_2)
    lay_4 = Conv2D(base_depth*2, kernel_size=(3, 3), padding='same',activation='relu')(lay_3)
    lay_5 = Conv2D(base_depth*2, kernel_size=(3, 3), padding='same',activation='relu')(lay_4)
    lay_6 = MaxPooling2D(pool_size=(2, 2))(lay_5)
    lay_7 = Conv2D(base_depth*4, kernel_size=(3, 3), padding='same',activation='relu')(lay_6)
    lay_8 = Conv2D(base_depth*4, kernel_size=(3, 3), padding='same',activation='relu')(lay_7)
    lay_9 = UpSampling2D((2, 2))(lay_8)
    lay_10 = concatenate([lay_5, lay_9])
    lay_11 = Conv2D(base_depth*2, kernel_size=(3, 3), padding='same',activation='relu')(lay_10)
    lay_12 = Conv2D(base_depth*2, kernel_size=(3, 3), padding='same',activation='relu')(lay_11)
    lay_13 = UpSampling2D((2, 2))(lay_12)
    lay_14 = concatenate([lay_2, lay_13])
    lay_15 = Conv2D(base_depth, kernel_size=(3, 3), padding='same',activation='relu')(lay_14)
    lay_16 = Conv2D(base_depth, kernel_size=(3, 3), padding='same',activation='relu')(lay_15)
    lay_17 = Conv2D(1, kernel_size=(1, 1), padding='same',
                    activation='relu')(lay_16)
    t_unet = Model(inputs=[in_img], outputs=[lay_17], name='SpotUNET')
    return t_unet

Model summary

In [33]:
t_unet = buildSpotUNet(base_depth=24)
t_unet.summary()
WARNING:tensorflow:From /home/travis/miniconda/envs/book/lib/python3.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
WARNING:tensorflow:From /home/travis/miniconda/envs/book/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:4070: The name tf.nn.max_pool is deprecated. Please use tf.nn.max_pool2d instead.

Model: "SpotUNET"
__________________________________________________________________________________________________
Layer (type)                    Output Shape              Param #     Connected to
==================================================================================================
Image_Input (InputLayer)        (None, None, None, 1)     0
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, None, None, 24)    240         Image_Input[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, None, None, 24)    5208        conv2d_1[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, None, None, 24)    0           conv2d_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, None, None, 48)    10416       max_pooling2d_1[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, None, None, 48)    20784       conv2d_3[0][0]
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, None, None, 48)    0           conv2d_4[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, None, None, 96)    41568       max_pooling2d_2[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, None, None, 96)    83040       conv2d_5[0][0]
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, None, None, 96)    0           conv2d_6[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, None, None, 144)   0           conv2d_4[0][0]
                                                                      up_sampling2d_1[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, None, None, 48)    62256       concatenate_1[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, None, None, 48)    20784       conv2d_7[0][0]
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D)  (None, None, None, 48)    0           conv2d_8[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, None, None, 72)    0           conv2d_2[0][0]
                                                                      up_sampling2d_2[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, None, None, 24)    15576       concatenate_2[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, None, None, 24)    5208        conv2d_9[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, None, None, 1)     25          conv2d_10[0][0]
==================================================================================================
Total params: 265,105
Trainable params: 265,105
Non-trainable params: 0
__________________________________________________________________________________________________

Prepare data for training and validation

In [34]:
train_img,  valid_img  = forig[128:256, 500:1300], forig[500:1000, 300:1500]
train_mask, valid_mask = mask[128:256, 500:1300], mask[500:1000, 300:1500]
wpos = [600,600]; ww   = 512
forigc = forig[wpos[0]:(wpos[0]+ww),wpos[1]:(wpos[1]+ww)]
maskc  = mask[wpos[0]:(wpos[0]+ww),wpos[1]:(wpos[1]+ww)]

# train_img, valid_img = forig[128:256, 300:1500], forig[500:, 300:1500]
# train_mask, valid_mask = mask[128:256, 300:1500], mask[500:, 300:1500]
fig, ax = plt.subplots(1, 4, figsize=(15, 6), dpi=300); ax=ax.ravel()

ax[0].imshow(train_img, cmap='bone',vmin=0,vmax=4000);ax[0].set_title('Train Image')
ax[1].imshow(train_mask, cmap='bone'); ax[1].set_title('Train Mask')

ax[2].imshow(valid_img, cmap='bone',vmin=0,vmax=4000); ax[2].set_title('Validation Image')
ax[3].imshow(valid_mask, cmap='bone');ax[3].set_title('Validation Mask');

Functions to prepare data for training

In [35]:
def prep_img(x, n=1):
    # Normalize with the training-image statistics after adding batch/channel dims
    return (prep_mask(x, n=n)-train_img.mean())/train_img.std()


def prep_mask(x, n=1):
    # Add channel and batch dimensions: (rows, cols) -> (n, rows, cols, 1)
    return np.stack([np.expand_dims(x, -1)]*n, 0)

Test the untrained model

  • We can make predictions with an untrained model (default parameters)
  • but we clearly do not expect them to be very good
In [36]:
unet_pred = t_unet.predict(prep_img(forigc))[0, :, :, 0]
WARNING:tensorflow:From /home/travis/miniconda/envs/book/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.

In [37]:
fig, m_axs = plt.subplots(2, 3, figsize=(15, 6), dpi=150)
for c_ax in m_axs.ravel():
    c_ax.axis('off')
((ax1, _, ax2), (ax3, ax4, ax5)) = m_axs
ax1.imshow(train_img, cmap='bone',vmin=0,vmax=4000); ax1.set_title('Train Image')
ax2.imshow(train_mask, cmap='viridis'); ax2.set_title('Train Mask')

ax3.imshow(forigc, cmap='bone',vmin=0, vmax=4000); ax3.set_title('Test Image')
ax4.imshow(unet_pred, cmap='viridis', vmin=0, vmax=1); ax4.set_title('Predicted Segmentation')

ax5.imshow(maskc, cmap='viridis'); ax5.set_title('Ground Truth');

Training conditions

  • Loss function - Binary cross-entropy
  • Optimizer - ADAM
  • 20 Epochs (training iterations)
  • Metrics
    1. Binary accuracy (percentage of correctly classified pixels) $$BA=\frac{1}{N}\sum_i(f_i==g_i)$$
    2. Mean absolute error

Another popular metric is the Dice score $$DSC=\frac{2|X \cap Y|}{|X|+|Y|}=\frac{2\,TP}{2TP+FP+FN}$$
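The Dice score is easy to implement directly from the set definition (a sketch for boolean masks, not part of the lecture code):

```python
import numpy as np

def dice(x, y):
    """Dice score for two boolean masks: 2|X intersect Y| / (|X|+|Y|)."""
    x, y = np.asarray(x, bool), np.asarray(y, bool)
    intersection = np.logical_and(x, y).sum()
    return 2.0 * intersection / (x.sum() + y.sum())

a = np.array([1, 1, 0, 0], dtype=bool)
b = np.array([1, 0, 1, 0], dtype=bool)
print(dice(a, b))  # 0.5: one true positive, one false positive, one false negative
```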

In [38]:
mlist = [
      metrics.TruePositives(name='tp'),        metrics.FalsePositives(name='fp'), 
      metrics.TrueNegatives(name='tn'),        metrics.FalseNegatives(name='fn'), 
      metrics.BinaryAccuracy(name='accuracy'), metrics.Precision(name='precision'),
      metrics.Recall(name='recall'),           metrics.AUC(name='auc'),
      metrics.MeanAbsoluteError(name='mae')]

t_unet.compile(
    loss=loss.BinaryCrossentropy(),  # we use the binary cross-entropy to optimize
    optimizer=opt.Adam(lr=1e-3),     # we use ADAM to optimize
    metrics=mlist                    # we keep track of the metrics in mlist
)
WARNING:tensorflow:From /home/travis/miniconda/envs/book/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:3172: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

A general note on the following demo

This is a very bad way to train a model:

  • the loss function is poorly chosen,
  • the optimizer and learning rate could be improved,
  • the training and validation data should not come from the same sample (and definitely not the same measurement).

The goal is to make you aware of these techniques and give you a feeling for how they can work on complex problems.

Training the spot detection model

In [39]:
loss_history = t_unet.fit(prep_img(train_img, n=3),
                          prep_mask(train_mask, n=3),
                          validation_data=(prep_img(valid_img),
                                           prep_mask(valid_mask)),
                          epochs=20,
                          verbose = 1)
Train on 3 samples, validate on 1 samples
Epoch 1/20
3/3 [==============================] - 10s 3s/step - loss: 0.1002 - tp: 0.0000e+00 - fp: 0.0000e+00 - tn: 304656.0000 - fn: 2544.0000 - accuracy: 0.9917 - precision: 0.0000e+00 - recall: 0.0000e+00 - auc: 0.4782 - mae: 0.0282 - val_loss: 0.1482 - val_tp: 1.0000 - val_fp: 1.0000 - val_tn: 593515.0000 - val_fn: 6483.0000 - val_accuracy: 0.9892 - val_precision: 0.5000 - val_recall: 1.5423e-04 - val_auc: 0.5501 - val_mae: 0.0109
Epoch 2/20
3/3 [==============================] - 7s 2s/step - loss: 0.1093 - tp: 0.0000e+00 - fp: 0.0000e+00 - tn: 304656.0000 - fn: 2544.0000 - accuracy: 0.9917 - precision: 0.0000e+00 - recall: 0.0000e+00 - auc: 0.5701 - mae: 0.0084 - val_loss: 0.0742 - val_tp: 3.0000 - val_fp: 2.0000 - val_tn: 593514.0000 - val_fn: 6481.0000 - val_accuracy: 0.9892 - val_precision: 0.6000 - val_recall: 4.6268e-04 - val_auc: 0.6937 - val_mae: 0.0185
Epoch 3/20
3/3 [==============================] - 7s 2s/step - loss: 0.0567 - tp: 0.0000e+00 - fp: 0.0000e+00 - tn: 304656.0000 - fn: 2544.0000 - accuracy: 0.9917 - precision: 0.0000e+00 - recall: 0.0000e+00 - auc: 0.6901 - mae: 0.0202 - val_loss: 0.0685 - val_tp: 16.0000 - val_fp: 12.0000 - val_tn: 593504.0000 - val_fn: 6468.0000 - val_accuracy: 0.9892 - val_precision: 0.5714 - val_recall: 0.0025 - val_auc: 0.7284 - val_mae: 0.0451
Epoch 4/20
3/3 [==============================] - 7s 2s/step - loss: 0.0598 - tp: 0.0000e+00 - fp: 0.0000e+00 - tn: 304656.0000 - fn: 2544.0000 - accuracy: 0.9917 - precision: 0.0000e+00 - recall: 0.0000e+00 - auc: 0.7722 - mae: 0.0417 - val_loss: 0.0554 - val_tp: 27.0000 - val_fp: 15.0000 - val_tn: 593501.0000 - val_fn: 6457.0000 - val_accuracy: 0.9892 - val_precision: 0.6429 - val_recall: 0.0042 - val_auc: 0.8299 - val_mae: 0.0297
Epoch 5/20
3/3 [==============================] - 7s 2s/step - loss: 0.0457 - tp: 6.0000 - fp: 0.0000e+00 - tn: 304656.0000 - fn: 2538.0000 - accuracy: 0.9917 - precision: 1.0000 - recall: 0.0024 - auc: 0.8643 - mae: 0.0262 - val_loss: 0.0867 - val_tp: 19.0000 - val_fp: 14.0000 - val_tn: 593502.0000 - val_fn: 6465.0000 - val_accuracy: 0.9892 - val_precision: 0.5758 - val_recall: 0.0029 - val_auc: 0.7730 - val_mae: 0.0119
Epoch 6/20
3/3 [==============================] - 7s 2s/step - loss: 0.0615 - tp: 0.0000e+00 - fp: 0.0000e+00 - tn: 304656.0000 - fn: 2544.0000 - accuracy: 0.9917 - precision: 0.0000e+00 - recall: 0.0000e+00 - auc: 0.7931 - mae: 0.0095 - val_loss: 0.0668 - val_tp: 33.0000 - val_fp: 19.0000 - val_tn: 593497.0000 - val_fn: 6451.0000 - val_accuracy: 0.9892 - val_precision: 0.6346 - val_recall: 0.0051 - val_auc: 0.8114 - val_mae: 0.0150
Epoch 7/20
3/3 [==============================] - 7s 2s/step - loss: 0.0474 - tp: 9.0000 - fp: 0.0000e+00 - tn: 304656.0000 - fn: 2535.0000 - accuracy: 0.9917 - precision: 1.0000 - recall: 0.0035 - auc: 0.8397 - mae: 0.0131 - val_loss: 0.0722 - val_tp: 83.0000 - val_fp: 68.0000 - val_tn: 593448.0000 - val_fn: 6401.0000 - val_accuracy: 0.9892 - val_precision: 0.5497 - val_recall: 0.0128 - val_auc: 0.7762 - val_mae: 0.0520
Epoch 8/20
3/3 [==============================] - 7s 2s/step - loss: 0.0619 - tp: 54.0000 - fp: 30.0000 - tn: 304626.0000 - fn: 2490.0000 - accuracy: 0.9918 - precision: 0.6429 - recall: 0.0212 - auc: 0.8104 - mae: 0.0461 - val_loss: 0.0800 - val_tp: 104.0000 - val_fp: 90.0000 - val_tn: 593426.0000 - val_fn: 6380.0000 - val_accuracy: 0.9892 - val_precision: 0.5361 - val_recall: 0.0160 - val_auc: 0.7575 - val_mae: 0.0599
Epoch 9/20
3/3 [==============================] - 7s 2s/step - loss: 0.0663 - tp: 75.0000 - fp: 45.0000 - tn: 304611.0000 - fn: 2469.0000 - accuracy: 0.9918 - precision: 0.6250 - recall: 0.0295 - auc: 0.7753 - mae: 0.0502 - val_loss: 0.0736 - val_tp: 97.0000 - val_fp: 71.0000 - val_tn: 593445.0000 - val_fn: 6387.0000 - val_accuracy: 0.9892 - val_precision: 0.5774 - val_recall: 0.0150 - val_auc: 0.7597 - val_mae: 0.0514
Epoch 10/20
3/3 [==============================] - 7s 2s/step - loss: 0.0581 - tp: 69.0000 - fp: 36.0000 - tn: 304620.0000 - fn: 2475.0000 - accuracy: 0.9918 - precision: 0.6571 - recall: 0.0271 - auc: 0.7775 - mae: 0.0406 - val_loss: 0.0700 - val_tp: 91.0000 - val_fp: 53.0000 - val_tn: 593463.0000 - val_fn: 6393.0000 - val_accuracy: 0.9892 - val_precision: 0.6319 - val_recall: 0.0140 - val_auc: 0.7565 - val_mae: 0.0426
Epoch 11/20
3/3 [==============================] - 7s 2s/step - loss: 0.0560 - tp: 57.0000 - fp: 27.0000 - tn: 304629.0000 - fn: 2487.0000 - accuracy: 0.9918 - precision: 0.6786 - recall: 0.0224 - auc: 0.7622 - mae: 0.0334 - val_loss: 0.0652 - val_tp: 77.0000 - val_fp: 37.0000 - val_tn: 593479.0000 - val_fn: 6407.0000 - val_accuracy: 0.9892 - val_precision: 0.6754 - val_recall: 0.0119 - val_auc: 0.7681 - val_mae: 0.0345
Epoch 12/20
3/3 [==============================] - 7s 2s/step - loss: 0.0535 - tp: 54.0000 - fp: 21.0000 - tn: 304635.0000 - fn: 2490.0000 - accuracy: 0.9918 - precision: 0.7200 - recall: 0.0212 - auc: 0.7682 - mae: 0.0271 - val_loss: 0.0579 - val_tp: 70.0000 - val_fp: 30.0000 - val_tn: 593486.0000 - val_fn: 6414.0000 - val_accuracy: 0.9893 - val_precision: 0.7000 - val_recall: 0.0108 - val_auc: 0.8066 - val_mae: 0.0269
Epoch 13/20
3/3 [==============================] - 7s 2s/step - loss: 0.0463 - tp: 48.0000 - fp: 12.0000 - tn: 304644.0000 - fn: 2496.0000 - accuracy: 0.9918 - precision: 0.8000 - recall: 0.0189 - auc: 0.8154 - mae: 0.0211 - val_loss: 0.0502 - val_tp: 66.0000 - val_fp: 24.0000 - val_tn: 593492.0000 - val_fn: 6418.0000 - val_accuracy: 0.9893 - val_precision: 0.7333 - val_recall: 0.0102 - val_auc: 0.8760 - val_mae: 0.0193
Epoch 14/20
3/3 [==============================] - 7s 2s/step - loss: 0.0380 - tp: 39.0000 - fp: 9.0000 - tn: 304647.0000 - fn: 2505.0000 - accuracy: 0.9918 - precision: 0.8125 - recall: 0.0153 - auc: 0.8912 - mae: 0.0152 - val_loss: 0.0476 - val_tp: 66.0000 - val_fp: 23.0000 - val_tn: 593493.0000 - val_fn: 6418.0000 - val_accuracy: 0.9893 - val_precision: 0.7416 - val_recall: 0.0102 - val_auc: 0.9248 - val_mae: 0.0126
Epoch 15/20
3/3 [==============================] - 7s 2s/step - loss: 0.0345 - tp: 36.0000 - fp: 9.0000 - tn: 304647.0000 - fn: 2508.0000 - accuracy: 0.9918 - precision: 0.8000 - recall: 0.0142 - auc: 0.9336 - mae: 0.0102 - val_loss: 0.0454 - val_tp: 74.0000 - val_fp: 23.0000 - val_tn: 593493.0000 - val_fn: 6410.0000 - val_accuracy: 0.9893 - val_precision: 0.7629 - val_recall: 0.0114 - val_auc: 0.9276 - val_mae: 0.0138
Epoch 16/20
3/3 [==============================] - 7s 2s/step - loss: 0.0339 - tp: 45.0000 - fp: 15.0000 - tn: 304641.0000 - fn: 2499.0000 - accuracy: 0.9918 - precision: 0.7500 - recall: 0.0177 - auc: 0.9376 - mae: 0.0125 - val_loss: 0.0434 - val_tp: 80.0000 - val_fp: 24.0000 - val_tn: 593492.0000 - val_fn: 6404.0000 - val_accuracy: 0.9893 - val_precision: 0.7692 - val_recall: 0.0123 - val_auc: 0.9342 - val_mae: 0.0153
Epoch 17/20
3/3 [==============================] - 7s 2s/step - loss: 0.0337 - tp: 51.0000 - fp: 15.0000 - tn: 304641.0000 - fn: 2493.0000 - accuracy: 0.9918 - precision: 0.7727 - recall: 0.0200 - auc: 0.9390 - mae: 0.0145 - val_loss: 0.0430 - val_tp: 92.0000 - val_fp: 29.0000 - val_tn: 593487.0000 - val_fn: 6392.0000 - val_accuracy: 0.9893 - val_precision: 0.7603 - val_recall: 0.0142 - val_auc: 0.9405 - val_mae: 0.0194
Epoch 18/20
3/3 [==============================] - 7s 2s/step - loss: 0.0339 - tp: 66.0000 - fp: 15.0000 - tn: 304641.0000 - fn: 2478.0000 - accuracy: 0.9919 - precision: 0.8148 - recall: 0.0259 - auc: 0.9514 - mae: 0.0174 - val_loss: 0.0414 - val_tp: 96.0000 - val_fp: 26.0000 - val_tn: 593490.0000 - val_fn: 6388.0000 - val_accuracy: 0.9893 - val_precision: 0.7869 - val_recall: 0.0148 - val_auc: 0.9458 - val_mae: 0.0172
Epoch 19/20
3/3 [==============================] - 7s 2s/step - loss: 0.0318 - tp: 66.0000 - fp: 15.0000 - tn: 304641.0000 - fn: 2478.0000 - accuracy: 0.9919 - precision: 0.8148 - recall: 0.0259 - auc: 0.9557 - mae: 0.0144 - val_loss: 0.0409 - val_tp: 96.0000 - val_fp: 23.0000 - val_tn: 593493.0000 - val_fn: 6388.0000 - val_accuracy: 0.9893 - val_precision: 0.8067 - val_recall: 0.0148 - val_auc: 0.9488 - val_mae: 0.0132
Epoch 20/20
3/3 [==============================] - 7s 2s/step - loss: 0.0303 - tp: 63.0000 - fp: 15.0000 - tn: 304641.0000 - fn: 2481.0000 - accuracy: 0.9919 - precision: 0.8077 - recall: 0.0248 - auc: 0.9589 - mae: 0.0107 - val_loss: 0.0447 - val_tp: 96.0000 - val_fp: 23.0000 - val_tn: 593493.0000 - val_fn: 6388.0000 - val_accuracy: 0.9893 - val_precision: 0.8067 - val_recall: 0.0148 - val_auc: 0.9343 - val_mae: 0.0117
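
The precision and recall columns in the log are not independent numbers; they follow directly from the confusion counts (tp, fp, fn) reported on the same line. For epoch 20 this is easy to check by hand:

```python
# Confusion counts reported in the log for epoch 20
tp, fp, fn = 63, 15, 2481

precision = tp / (tp + fp)   # fraction of predicted spot pixels that are real spots
recall    = tp / (tp + fn)   # fraction of real spot pixels that were found

print(f"precision = {precision:.4f}")  # 0.8077, as in the log
print(f"recall    = {recall:.4f}")     # 0.0248, as in the log
```

The very low recall at this stage shows that accuracy alone (0.99 throughout) says little for such an unbalanced problem; the confusion-based metrics are the ones to watch.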

Training history plots

In [40]:
titleDict = {'tp':        "True Positives",
             'fp':        "False Positives",
             'tn':        "True Negatives",
             'fn':        "False Negatives",
             'accuracy':  "Binary Accuracy",
             'precision': "Precision",
             'recall':    "Recall",
             'auc':       "Area under Curve",
             'mae':       "Mean absolute error"}

fig, ax = plt.subplots(2, 5, figsize=(20,8), dpi=300)
ax = ax.ravel()
for idx, key in enumerate(titleDict.keys()):
    ax[idx].plot(loss_history.epoch, loss_history.history[key],        color='coral',          label='Training')
    ax[idx].plot(loss_history.epoch, loss_history.history['val_'+key], color='cornflowerblue', label='Validation')
    ax[idx].set_title(titleDict[key])

ax[9].axis('off');
axLine, axLabel = ax[0].get_legend_handles_labels() # Take the labels and line handles from the first panel
fig.legend(axLine, axLabel, bbox_to_anchor=(0.7, 0.3), loc='upper left');

Prediction on the training data

In [41]:
unet_train_pred = t_unet.predict(prep_img(train_img[:, wpos[1]:(wpos[1]+ww)]))[0, :, :, 0]

fig, m_axs = plt.subplots(1, 3, figsize=(18, 4), dpi=150)
m_axs = m_axs.ravel()
for c_ax in m_axs: c_ax.axis('off')

m_axs[0].imshow(train_img[:, wpos[1]:(wpos[1]+ww)], cmap='bone', vmin=0, vmax=4000); m_axs[0].set_title('Train Image')
m_axs[1].imshow(unet_train_pred, cmap='viridis', vmin=0, vmax=0.2); m_axs[1].set_title('Predicted Training')
m_axs[2].imshow(train_mask[:, wpos[1]:(wpos[1]+ww)], cmap='viridis'); m_axs[2].set_title('Train Mask');

Prediction using unseen data

In [42]:
unet_pred = t_unet.predict(prep_img(forigc))[0, :, :, 0]

fig, m_axs = plt.subplots(1, 3, figsize=(18, 4), dpi=150)
m_axs = m_axs.ravel()
for c_ax in m_axs: c_ax.axis('off')
m_axs[0].imshow(forigc, cmap='bone', vmin=0, vmax=4000); m_axs[0].set_title('Full Image')
f1 = m_axs[1].imshow(unet_pred, cmap='viridis', vmin=0, vmax=0.1); m_axs[1].set_title('Predicted Segmentation'); fig.colorbar(f1, ax=m_axs[1]);
m_axs[2].imshow(maskc, cmap='viridis'); m_axs[2].set_title('Ground Truth');

Converting predictions to segments

In [43]:
fig, ax = plt.subplots(1, 2, figsize=(12,4))
ax0 = ax[0].imshow(unet_pred, vmin=0, vmax=0.1); ax[0].set_title('Predicted segmentation'); fig.colorbar(ax0, ax=ax[0])
ax[1].imshow(0.05 < unet_pred); ax[1].set_title('Final segmentation');
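
The conversion from the continuous network output to a binary segmentation is a simple fixed threshold (0.05 here, chosen by inspecting the prediction image). Once the segments exist, overlap measures such as the Dice coefficient follow from the same confusion counts used during training. A minimal sketch, with small synthetic arrays standing in for `unet_pred` and `maskc`:

```python
import numpy as np

# Synthetic stand-ins for the ground truth (maskc) and
# the network output (unet_pred) used above.
gt   = np.array([[0, 0, 1, 1],
                 [0, 0, 1, 1]], dtype=bool)
pred = np.array([[0.01, 0.02, 0.20, 0.30],
                 [0.04, 0.06, 0.03, 0.25]])

seg = 0.05 < pred            # same thresholding as in the cell above

tp = np.sum( seg &  gt)      # correctly detected pixels
fp = np.sum( seg & ~gt)      # false alarms
fn = np.sum(~seg &  gt)      # missed pixels

dice = 2 * tp / (2 * tp + fp + fn)   # overlap between seg and gt
print(tp, fp, fn, dice)      # → 3 1 1 0.75
```

Sweeping the threshold and plotting Dice against it is a quick way to check how sensitive the final segmentation is to the chosen value.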

Hit cases

In [44]:
gt = maskc
pr = 0.05<unet_pred
ps.showHitCases(gt,pr,cmap='gray')

Hit map

In [45]:
fig, ax = plt.subplots(1,2,figsize=(12,4))

ps.showHitMap(gt,pr,ax=ax)
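
`showHitMap` comes from the lecture's own `plotsupport` module, whose implementation is not shown here. The core idea of a hit map can, however, be sketched with plain numpy by encoding the four possible outcomes of each pixel in a single label image (a sketch, not the module's actual code):

```python
import numpy as np

# Tiny stand-ins for the ground truth (gt) and prediction (pr) masks.
gt = np.array([[0, 1],
               [1, 0]], dtype=bool)
pr = np.array([[1, 1],
               [0, 0]], dtype=bool)

# Encode each pixel's fate in one label image:
# 0 = true negative, 1 = false positive,
# 2 = false negative, 3 = true positive.
hitmap = 2 * gt.astype(int) + pr.astype(int)
print(hitmap)
# → [[1 3]
#    [2 0]]
```

Showing `hitmap` with `imshow` and a four-color `ListedColormap` then makes hits, misses, and false alarms visible at a glance in a single panel.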

Concluding remarks about the spot detection

Segmenting root networks in the rhizosphere using a U-Net

Background

  • Soil, and in particular the rhizosphere, is of central interest for neutron imaging users.
  • The experiments aim to follow the water distribution near the roots.
  • The roots must be identified in 2D and 3D data.
  • Today, much of this mark-up is still done manually!

Available data

Considered NN models

Loss functions

Training

Results

Summary

Future Machine learning challenges in neutron imaging

Concluding remarks
